import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio, mean_squared_error, normalized_root_mse
import data_creation


def logit(x):
    """Return the element-wise logit ``log(x) - log(1 - x)`` of a tensor.

    The input is first clamped to ``[0.01, 0.99]`` so that both logarithms
    stay finite; values outside that interval therefore saturate at roughly
    +/-4.595.

    Args:
        x: Tensor of values, nominally in ``[0, 1]``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    eps = 1e-2
    clamped = x.clamp(min=eps, max=1 - eps)
    return clamped.log() - (1 - clamped).log()


def mahalanobis_clip(r_mu, mahalonobis_d, rep_dim=5, l1_clip=True):
    """Shrink rows of ``r_mu`` whose norm exceeds a radius back onto that radius.

    Rows are rescaled (direction preserved) so that their norm equals the
    radius; rows already inside the radius are left alone. ``r_mu`` is
    modified IN PLACE and also returned.

    Args:
        r_mu: 2-D tensor; norms are taken along ``dim=1``.
        mahalonobis_d: Clipping radius used for the L1 clip. Empirical L1
            values noted for 5-D representations: 8.328 (99.7% of points),
            6.369 (95%), 5.783898 (90%).
        rep_dim: Representation dimensionality; only ``5`` is supported.
            It is not checked against the actual shape of ``r_mu``.
        l1_clip: If True clip on the L1 norm; otherwise clip on the L2 norm.

    Returns:
        The same tensor object ``r_mu`` after clipping.

    Raises:
        NotImplementedError: If ``rep_dim`` is not 5.

    Note:
        In the L2 branch the caller-supplied ``mahalonobis_d`` is ignored and
        a fixed radius of 4.25 (empirical, 99.7% of points in 5-D) is used.
    """
    if rep_dim != 5:
        raise NotImplementedError('Currently only implemented for rep_dim=5.')

    if l1_clip:
        radius = mahalonobis_d

        def row_norm(rows):
            return rows.abs().sum(dim=1)
    else:
        # Fixed L2 radius; deliberately overrides the argument.
        radius = 4.25

        def row_norm(rows):
            return torch.norm(rows, dim=1)

    outside = row_norm(r_mu) > radius
    offenders = r_mu[outside]
    shrink = radius / row_norm(offenders).reshape(-1, 1)
    r_mu[outside] = shrink * offenders
    return r_mu

def calculate_delta_f(data, task, dim_reduce=False):
    """Compute the L1 sensitivity ``delta_f`` of a dataset and a per-feature proxy.

    Args:
        data: Tensor whose first dimension indexes datapoints. The pairwise
            L1 norm is taken over the last dimension (assumes a 2-D
            ``(N, D)`` layout -- confirm against callers).
        task: Task name. For ``'MNIST'`` without dimensionality reduction a
            precomputed constant is returned instead of the pairwise search.
        dim_reduce: Whether the data has been dimensionality-reduced.

    Returns:
        Tuple ``(delta_f, delta_f_proxy)``:
          * ``delta_f`` -- the float ``2751.4854`` for un-reduced MNIST
            (value computed with the logit transform; 304.4582 was the value
            without it). Otherwise the largest pairwise L1 distance as a
            0-dim tensor, or the int ``0`` if no distance is positive.
          * ``delta_f_proxy`` -- per-feature range ``max - min`` over dim 0.
    """
    col_max, _ = data.max(dim=0)
    col_min, _ = data.min(dim=0)
    delta_f_proxy = col_max - col_min

    if task == 'MNIST' and not dim_reduce:
        # Precomputed on logit-transformed MNIST; avoids the O(N^2) search.
        return 2751.4854, delta_f_proxy

    n_points = data.shape[0]
    delta_f = 0
    for idx, anchor in enumerate(data):
        # Largest L1 distance from this anchor to any datapoint.
        farthest = torch.norm(data - anchor.unsqueeze(0), p=1, dim=-1).max()
        if farthest > delta_f:
            delta_f = farthest
        if idx % 1000 == 0:
            print("{}/{} datapoints checked".format(idx, n_points))

    return delta_f, delta_f_proxy

# def calculate_posterior_clips(data):
#     npdata=data.cpu().numpy()
#     l1norms = np.linalg.norm(npdata, axis=1,ord=1)
#     clip_dict = {}
#     for percentile in [.997, .975, .95, .90, .85, .80, .75]:
#         idx = int(percentile*len(l1norms))
#         clip_dict[percentile] = np.sort(l1norms)[idx]
#         print("L1 norm of "+str(100*percentile)+"th percentile is "+str(np.sort(l1norms)[idx]))
#     return clip_dict


def calculate_epsilon(scale, device, subset_size=None):
    """Estimate epsilon as the largest scale-normalised L1 distance in MNIST.

    Loads the MNIST training split from ``'../_datasets/mnist'`` (downloading
    it if necessary), applies ``ToTensor`` followed by ``logit``, flattens
    every image, and returns

        max_i max_j || (x_j - x_i) / scale ||_1

    where ``j`` ranges over the whole dataset and ``i`` over the checked
    anchor points.

    Args:
        scale: Tensor of noise scales; it is reshaped to ``(1, -1)`` and must
            broadcast against the flattened images.
        device: Device on which the computation is performed.
        subset_size: Number of anchor points to check. ``None`` (default)
            checks every datapoint; otherwise at most ``subset_size`` anchors
            (capped at the dataset size) are used.

    Returns:
        The maximum as a 0-dim tensor, or the int ``0`` if no anchor was
        checked or no distance is positive.
    """
    data_transforms = [transforms.ToTensor(), logit]
    dataset = datasets.MNIST('../_datasets/mnist', train=True, download=True,
                             transform=transforms.Compose(data_transforms))
    L = len(dataset)
    # A single batch containing the whole dataset.
    loader = DataLoader(dataset, batch_size=L)
    data = next(iter(loader))[0]
    data = data.view(data.shape[0], -1).to(device)
    # Loop-invariant: move and reshape the scale once.
    scale = scale.to(device).view(1, -1)

    # Resolve the number of anchors up front so that the default
    # ``subset_size=None`` works and exactly that many anchors are processed.
    if subset_size is None:
        n_anchors = L
    else:
        n_anchors = max(0, min(L, subset_size))

    max_eps = 0
    for i, img in enumerate(data[:n_anchors]):
        iter_max = torch.max(torch.norm((data - img.unsqueeze(0)) / scale, p=1, dim=-1))
        if iter_max > max_eps:
            max_eps = iter_max
        if i % 500 == 0:
            print("{}/{} datapoints checked".format(i, n_anchors))
    return max_eps
    
    
def calculate_epsilon_from_noise(x_noise, y_noise, delta_f, n_cat=10, x_noise_type='Laplace'):
    """Return the total privacy budget implied by given feature and label noise.

    Args:
        x_noise: Standard deviation of the Laplace noise added to the
            features (``std = sqrt(2) * scale``).
        y_noise: Label noise level used in the randomised-response term.
        delta_f: Sensitivity of the features.
        n_cat: Number of label categories.
        x_noise_type: Feature noise distribution; only ``'Laplace'`` is
            supported.

    Returns:
        ``epsilon_x + epsilon_y``, or ``np.inf`` when either noise level is
        zero (no noise means no privacy guarantee).

    Raises:
        NotImplementedError: If ``x_noise_type`` is not ``'Laplace'``.
    """
    # Compare by value: ``is not`` on a string literal tests object identity
    # and can wrongly reject an equal but non-interned string.
    if x_noise_type != 'Laplace':
        raise NotImplementedError("Only Laplace noise is implemented")
    if x_noise == 0.0 or y_noise == 0.0:
        return np.inf
    # Laplace scale b = x_noise / sqrt(2), hence epsilon_x = delta_f / b.
    epsilon_x = delta_f / x_noise * np.sqrt(2)
    partial_epsilon_y = np.log((n_cat - 1) * (1 - y_noise) / y_noise)
    # The sign depends on which side of the uniform point y_noise lies;
    # take the magnitude ("we assume we flip").
    epsilon_y = np.maximum(partial_epsilon_y, -partial_epsilon_y)
    return epsilon_x + epsilon_y
    
def calculate_xnoise_ynoise_from_epsilon(opt):
    """Translate a total privacy budget into feature and label noise levels.

    Reads from ``opt``: ``epsilon``, ``epsilon_split``, ``delta_f``,
    ``noise_type``, ``n_categories``, ``noise_features_directly``,
    ``ncat_of_cat_features``, ``data_join_task``, ``novel_class`` and
    ``tabular``.

    The budget is split as ``epsilon * epsilon_split`` for the features and
    the remainder for the labels, unless ``data_join_task`` or
    ``novel_class`` is set, in which case the labels get no noise and the
    whole budget goes to the features.

    Args:
        opt: Options object exposing the attributes listed above.

    Returns:
        Tuple ``(x_noise, y_noise)``. ``y_noise`` is the label flip
        probability ``(n_cat - 1) / (n_cat - 1 + exp(epsilon_y))`` (or
        ``0.``). ``x_noise`` is the Laplace standard deviation
        ``sqrt(2) * delta_f / epsilon_x``; when features are noised directly
        on tabular data it is a tensor holding one standard deviation per
        continuous feature followed by one flip probability per categorical
        feature, with the feature budget shared equally among all features.

    Raises:
        NotImplementedError: If ``opt.noise_type`` is not ``'Laplace'``.
    """
    total_epsilon = opt.epsilon
    x_share = opt.epsilon_split
    sensitivity = opt.delta_f
    noise_kind = opt.noise_type
    n_labels = opt.n_categories
    per_feature_noise = opt.noise_features_directly
    cat_cardinalities = opt.ncat_of_cat_features
    skip_label_noise = opt.data_join_task or opt.novel_class
    is_tabular = opt.tabular

    if noise_kind != 'Laplace':
        raise NotImplementedError("Only Laplace noise is implemented")

    if skip_label_noise:
        y_noise = 0.
        epsilon_x = total_epsilon
    else:
        epsilon_x = total_epsilon * x_share
        epsilon_y = total_epsilon * (1 - x_share)
        y_noise = (n_labels - 1) / (n_labels - 1 + np.exp(epsilon_y))

    if not (per_feature_noise and is_tabular):
        # Single Laplace standard deviation for the whole feature vector.
        return np.sqrt(2) * sensitivity / epsilon_x, y_noise

    # Every continuous and categorical feature receives an equal share.
    n_features = len(sensitivity) + len(cat_cardinalities)
    eps_per_feature = epsilon_x / n_features
    cont_noise = np.sqrt(2) * sensitivity / eps_per_feature
    cat_noise = torch.Tensor(
        [(k - 1) / (k - 1 + np.exp(eps_per_feature)) for k in cat_cardinalities]
    )
    # x_noise is a tensor in this case: continuous entries first.
    return torch.cat((cont_noise, cat_noise)), y_noise


def calculate_image_metrics(data_loader, encoder, decoder, x_noise, md, device, noise_type='Laplace'):
    """Average image-quality metrics between clean images and their noisy versions.

    For each batch the noisy images are obtained either by decoding the
    encoder's latents (when ``decoder`` is given) or by adding noise directly
    to the data via ``data_creation.add_data_noise``. Only the first channel
    (index 0) of each image is compared.

    Args:
        data_loader: Iterable yielding ``(data, label)`` batches.
        encoder: Object providing ``get_data_representatation``, or ``None``.
        decoder: Object providing ``get_data_reconstruction``, or ``None``.
            Requires ``encoder`` to be set as well.
        x_noise: Noise level passed to the decoder or to ``add_data_noise``.
        md: Clipping value forwarded as ``clip=`` to encoder and decoder.
        device: Device the batches are moved to.
        noise_type: Noise distribution used when noising the data directly.

    Returns:
        Dict with the per-image means under ``'ssim'``, ``'psnr'``, ``'mse'``
        and ``'nrmse'``.

    Raises:
        ValueError: If a decoder is supplied without an encoder, or if the
            loader yields no images.
    """
    if decoder is not None and encoder is None:
        # Without an encoder there are no latents for the decoder to consume.
        raise ValueError("A decoder was given without an encoder to produce its latents.")

    num_images_processed = 0
    ssim = 0.
    psnr = 0.
    mse = 0.
    nrmse = 0.
    image_metrics = {}

    with torch.no_grad():
        for data, _ in data_loader:
            clean_data = data.to(device)
            if encoder is not None:
                latents = encoder.get_data_representatation(clean_data, data_loader=False, clip=md)
            if decoder is not None:
                noisy_data = decoder.get_data_reconstruction(latents.to(device), x_noise, clip=md)
            else:
                noisy_data = data_creation.add_data_noise(clean_data, x_noise, noise_type)

            for i in range(data.size(0)):
                # Convert each image to NumPy once and reuse it for all metrics.
                clean_img = clean_data[i][0].cpu().numpy()
                noisy_img = noisy_data[i][0].cpu().numpy()
                num_images_processed += 1
                # NOTE(review): no data_range is passed here; recent
                # scikit-image versions may require it for float images.
                ssim += structural_similarity(clean_img, noisy_img)
                # NOTE(review): data_range=10 is hard-coded -- presumably the
                # span of the transformed pixel values; confirm.
                psnr += peak_signal_noise_ratio(clean_img, noisy_img, data_range=10)
                mse += mean_squared_error(clean_img, noisy_img)
                nrmse += normalized_root_mse(clean_img, noisy_img)

    if num_images_processed == 0:
        raise ValueError("data_loader yielded no images; cannot average metrics.")

    image_metrics['ssim'] = ssim / num_images_processed
    image_metrics['psnr'] = psnr / num_images_processed
    image_metrics['mse'] = mse / num_images_processed
    image_metrics['nrmse'] = nrmse / num_images_processed
    return image_metrics
